import os
import csv
import json
from xlrd import open_workbook

def to_float(str_in):
	"""Convert *str_in* to a float, or return None if it cannot be converted.

	Args:
		str_in: value to convert; in this script, a cell read from a CSV file.

	Returns:
		The float value, or None when ``float()`` rejects the input
		(non-numeric text, empty string, or None).
	"""
	try:
		return float(str_in)
	# Only conversion failures are expected here; a bare ``except`` would also
	# swallow KeyboardInterrupt/SystemExit and hide unrelated bugs.
	except (TypeError, ValueError):
		return None

#######################################
# Replication materials for "Messy Data, Robust Inference? Navigating Obstacles to Inference with bigKRLS"
# By: Pete Mohanty (pmohanty@stanford.edu) and Robert Shaffer (rbshaffer@utexas.edu)
####################################
# change to reflect your own paths #
####################################

# Directory holding every input file read by this script.
dataset_base = '/home/rbshaffer/Dropbox/pa_replication/2016_election_application/2016_election_data/datasets'
# Destination of the merged CSV.
# NOTE(review): the '~' segment in the middle of this path looks like a copy/paste slip - confirm the intended location.
out_path = '/home/rbshaffer/Dropbox/pa_replication/2016_election_application/2016_election_data/~/2016_election_application/2016_election_data.csv'

# list of variables to extract
# Each entry is matched verbatim against the header row of the spreadsheets.
vars_to_extract = [
	'Rural-urban_Continuum_Code_2013',
	'POVALL_2015',
	'Unemployment_rate_2015',
	'Median_Household_Income_2015',
	'Med_HH_Income_Percent_of_State_Total_2015',
	'POP_ESTIMATE_2015',
	'Percent of adults with less than a high school diploma, 2011-2015',
	'Percent of adults with a high school diploma only, 2011-2015',
	'Percent of adults completing some college or associate\'s degree, 2011-2015',
	'Percent of adults with a bachelor\'s degree or higher, 2011-2015',
	'AGE050210D',
]

# Guard against vars_to_extract being edited into something other than a list.
# The reply is parsed with ast.literal_eval instead of being evaluated as code
# (which is what Python 2's input() does); on Python 3 input() returns a str,
# so the original loop could never terminate.
if not isinstance(vars_to_extract, list):
	import ast

	try:
		read_line = raw_input  # Python 2: returns the typed text unevaluated
	except NameError:
		read_line = input  # Python 3

	while not isinstance(vars_to_extract, list):
		try:
			vars_to_extract = ast.literal_eval(read_line('Please re-enter your variable names as a list!\n'))
		except (SyntaxError, ValueError):
			# Not a valid Python literal; ask again.
			vars_to_extract = None

# Full path of every file in the dataset directory (used further down to find
# the Excel workbooks).
dataset_paths = [os.path.join(dataset_base, p) for p in os.listdir(dataset_base)]

# Master table: county key (column 1 of election_results.csv, later matched
# against FIPS codes with leading zeros stripped) -> dict of output columns.
out = {}

# Columns that later passes fill in. Every record starts with all of them set
# to None so that all records share the same key set when written out.
#  - 'all_mortality_2013-2015' was previously misspelled '...2003-2015', which
#    never matched the key written by the period-specific mortality pass.
#  - 'total_mortality' is assigned by the mortality_all.csv pass but was never
#    pre-initialised.
_placeholder_columns = [
	'white_mortality', 'latino_mortality', 'black_mortality',
	'asian_mortality', 'na_mortality',
	'all_mortality_2009-2011',
	'all_mortality_2013-2015',
	'total_mortality',
	'white_2009-2011_despair_mortality',
	'all_2009-2011_despair_mortality',
	'white_2013-2015_despair_mortality',
	'all_2013-2015_despair_mortality',
	'all_1999-2015_despair_mortality',
	'white_population', 'latino_population', 'black_population',
	'asian_population', 'na_population', 'total_population',
	'lat', 'lon',
]

# get election results
with open(os.path.join(dataset_base, 'election_results.csv'), 'r') as f:
	results = csv.reader(f)
	# Skip the header row; next(..., None) works on Python 2 and 3 and does not
	# raise on an empty file.
	next(results, None)

	for row in results:
		record = {
			# Vote columns may be non-numeric; to_float maps those to None.
			'dem_2016': to_float(row[2]),
			'gop_2016': to_float(row[3]),
			'dem_2012': to_float(row[13]),
			'gop_2012': to_float(row[14]),
			# Unlike the vote columns, this raises if the cell is not numeric.
			'total_votes': float(row[4]),
			'county_name': row[10],
			'state': row[9],
		}
		record.update(dict.fromkeys(_placeholder_columns))
		record.update(dict.fromkeys(vars_to_extract))
		out[row[1]] = record

# get geolocations
# The gazetteer file is read as tab-delimited even though it is named .csv.
with open(os.path.join(dataset_base, 'Gaz_counties_national.csv'), 'r') as f:
	geo = csv.reader(f, delimiter='\t')
	# Skip the header row (portable across Python 2 and 3).
	next(geo, None)
	for row in geo:
		# Ignore blank/truncated lines instead of raising IndexError; columns
		# 1, 10 and 11 are all needed below.
		if len(row) < 12:
			continue

		# Keys of `out` carry no leading zeros, so normalise before the lookup.
		fips = row[1].lstrip('0')
		record = out.get(fips)
		if record is None:
			# County has no election record: report its code and move on.
			print(fips)
			continue

		# Coordinates are stored exactly as read (strings, not floats).
		record['lat'] = row[10]
		record['lon'] = row[11]

def _parse_rate(text):
	"""Convert a rate cell to float, dropping an '(Unreliable)' flag if present.

	The original code used ``text.strip(' (Unreliable)')``, which strips a
	*set of characters* from both ends rather than the literal suffix and only
	worked because digits and '.' are not in that set.

	Raises:
		ValueError: if what remains is not a number.
	"""
	return float(text.replace('(Unreliable)', '').strip())


def _read_county_rows(filename):
	"""Yield ``(fips, row)`` for each usable row of a tab-delimited dataset file.

	The header row is skipped. A row is yielded only when it has more than
	three columns and its column 2, with leading zeros stripped, is a key of
	``out``.
	"""
	with open(os.path.join(dataset_base, filename), 'r') as f:
		reader = csv.reader(f, delimiter='\t')
		next(reader, None)  # header; portable across Python 2 and 3
		for row in reader:
			if len(row) > 3:
				fips = row[2].lstrip('0')
				if fips in out:
					yield fips, row


def _race_label(row):
	"""Map a mortality_categories.csv row to the column prefix used in ``out``.

	Rows whose column 5 contains 'White' are labelled 'white' when column 3
	contains 'Not' and 'latino' otherwise; 'Black' and 'Asian' map to 'black'
	and 'asian'; everything else is 'na'.
	"""
	# NOTE(review): presumably column 3 is the Hispanic-origin label and
	# column 5 the race label - confirm against the file header.
	race = row[5]
	if 'White' in race:
		return 'white' if 'Not' in row[3] else 'latino'
	if 'Black' in race:
		return 'black'
	if 'Asian' in race:
		return 'asian'
	return 'na'


# get total mortality (one file per period; despair files are handled below)
mortality_files = [path for path in os.listdir(dataset_base) if 'mortality_2' in path and 'despair' not in path]
for path in mortality_files:
	print(path)
	# Any file without '2009' in its name is treated as the 2013-2015 period.
	date = '2009-2011' if '2009' in path else '2013-2015'
	subset = 'all_mortality_' + date

	for fips, row in _read_county_rows(path):
		out[fips][subset] = _parse_rate(row[6])


# get mortality for race categories
# This pass and the despair pass below each appeared twice in the script with
# identical code; the second copies only repeated the same assignments and
# have been dropped.
for fips, row in _read_county_rows('mortality_categories.csv'):
	if row[10] == 'Not Applicable':
		continue

	race_var = _race_label(row)
	out[fips][race_var + '_mortality'] = _parse_rate(row[10])
	out[fips][race_var + '_population'] = float(row[8])


# get mortality for "deaths of despair" only
despair_files = [path for path in os.listdir(dataset_base) if 'despair' in path]
for path in despair_files:
	race = 'white' if 'white' in path else 'all'

	if '2009' in path:
		dates = '2009-2011'
	elif '2013' in path:
		dates = '2013-2015'
	else:
		dates = '1999-2015'

	subset = race + '_' + dates + '_' + 'despair_mortality'

	for fips, row in _read_county_rows(path):
		out[fips][subset] = _parse_rate(row[6])


# get total mortality (single file covering all counties)
for fips, row in _read_county_rows('mortality_all.csv'):
	out[fips]['total_mortality'] = _parse_rate(row[6])
	out[fips]['total_population'] = float(row[4])

# open all databases, get requested variables
# Substring match on the full path, so '.xlsx' files are selected as well.
xls_paths = [path for path in dataset_paths if 'xls' in path]

# Header names that may hold the county FIPS code, in order of preference.
_fips_headers = ('STCOU', 'FIPS Code', 'FIPS', 'FIPStxt')

for path in xls_paths:
	book = open_workbook(path)
	for name in book.sheet_names():
		sheet = book.sheet_by_name(name)
		# An empty sheet has no header row to inspect.
		if sheet.nrows == 0:
			continue

		varnames = [c.value for c in sheet.row(0)]

		# Column index of the first recognised FIPS header, or None.
		fips_ind = next((varnames.index(h) for h in _fips_headers if h in varnames), None)

		# column index -> requested variable name, for variables in this sheet
		var_inds = {varnames.index(var): var for var in vars_to_extract if var in varnames}

		if fips_ind is None or not var_inds:
			continue

		for i in range(1, sheet.nrows):
			row = sheet.row(i)

			fips = row[fips_ind].value
			# Numeric cells have no .lstrip(); render them as an integer string
			# first (assumes xlrd returns numbers as floats - confirm).
			if isinstance(fips, (int, float)):
				fips = str(int(fips))
			fips = fips.lstrip('0')

			# The membership test does not depend on the variable, so do it
			# once per row; an unmatched code is now printed once instead of
			# once per extracted variable.
			if fips not in out:
				print(fips)
				continue

			for v, vname in var_inds.items():
				out[fips][vname] = row[v].value

# dump the output
# Build the header from the keys of *every* record. Taking them from a single
# arbitrary record makes DictWriter raise ValueError as soon as another record
# carries a key the first one lacks, and fails outright when `out` is empty.
_columns = set()
for _record in out.values():
	_columns.update(_record)
_columns.discard('fips')
# 'fips' first, remaining columns sorted so the column order is reproducible.
fieldnames = ['fips'] + sorted(_columns)

with open(out_path, 'w') as f:
	# Keys missing from a record are written as empty cells (DictWriter's
	# default restval).
	writer = csv.DictWriter(f, fieldnames=fieldnames)

	writer.writeheader()

	# .items() works on both Python 2 and 3.
	for key, value in out.items():
		value['fips'] = key
		writer.writerow(value)


